﻿# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Precision/Recall (PR) from the paper "Improved Precision and Recall
Metric for Assessing Generative Models". Matches the original implementation
by Kynkaanniemi et al. at
https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""

import torch
from . import metric_utils

# ----------------------------------------------------------------------------


def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
    """Compute the pairwise Euclidean distance matrix between two feature sets.

    The columns (``col_features`` rows) are zero-padded and split into equally
    sized batches. The number of batches is a multiple of ``num_gpus``. Each
    rank evaluates ``torch.cdist`` for its own slice of batches. The results
    are then broadcast from every rank in turn, so that rank 0 can stitch the
    complete matrix together in the original column order.

    Every rank must call this function with matching arguments, because the
    broadcasts are collective operations when ``num_gpus > 1``.

    Args:
        row_features: Feature tensor whose leading dimension indexes rows
            (presumably shape ``(num_rows, dim)``; confirm with callers).
        col_features: Feature tensor whose leading dimension indexes columns.
        num_gpus: Number of participating processes.
        rank: Index of this process, in ``[0, num_gpus)``.
        col_batch_size: Target number of columns per batch.

    Returns:
        On rank 0, a CPU tensor of shape ``(num_rows, col_features.shape[0])``.
        On every other rank, ``None``.
    """
    assert 0 <= rank < num_gpus
    total_cols = col_features.shape[0]

    # Round the batch count up to a multiple of num_gpus so every rank gets the
    # same number of batches (required for matched collective calls).
    batches_per_gpu = (total_cols - 1) // col_batch_size // num_gpus + 1
    batch_count = batches_per_gpu * num_gpus

    # Zero-pad the trailing rows so chunk() yields exactly batch_count equal
    # pieces; the padded columns are sliced off again before returning.
    pad_rows = -total_cols % batch_count
    padded_cols = torch.nn.functional.pad(col_features, [0, 0, 0, pad_rows])
    all_batches = padded_cols.chunk(batch_count)

    is_leader = rank == 0
    gathered = []
    for own_batch in all_batches[rank::num_gpus]:
        own_dist = torch.cdist(row_features.unsqueeze(0), own_batch.unsqueeze(0))[0]
        # Broadcast from each source rank in turn. Rank 0 receives the batches
        # in ascending global order (src + step * num_gpus).
        for source in range(num_gpus):
            buffer = own_dist.clone()
            if num_gpus > 1:
                torch.distributed.broadcast(buffer, src=source)
            gathered.append(buffer.cpu() if is_leader else None)

    if not is_leader:
        return None
    return torch.cat(gathered, dim=1)[:, :total_cols]


# ----------------------------------------------------------------------------


def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
    """Compute improved precision and recall between real and generated features.

    Features for both sets come from the VGG16 detector at ``detector_url``,
    via ``metric_utils``. Each set is cast to float16 and moved to
    ``opts.device``. For each metric, one set acts as the "manifold" and the
    other as the "probes":

    * precision: the manifold is the real features; the probes are generated.
    * recall: the manifold is the generated features; the probes are real.

    Each manifold sample gets a radius equal to its distance to the
    ``nhood_size``-th nearest other manifold sample. A probe is counted as
    covered if it lies within the radius of at least one manifold sample. The
    metric is the fraction of covered probes.

    All ranks must call this function, because ``compute_distances`` issues
    collective broadcasts when ``opts.num_gpus > 1``.

    Args:
        opts: Metric options object. This function reads ``opts.device``,
            ``opts.num_gpus`` and ``opts.rank``, and forwards ``opts`` to
            ``metric_utils``.
        max_real: Maximum number of real samples to use.
        num_gen: Number of generated samples to use.
        nhood_size: Neighbourhood size k used to define each manifold radius.
        row_batch_size: Number of rows per ``compute_distances`` call.
        col_batch_size: Column batch size forwarded to ``compute_distances``.

    Returns:
        ``(precision, recall)`` as Python floats on rank 0. Other ranks get
        ``(nan, nan)``.
    """
    detector_url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt"
    detector_kwargs = {"return_features": True}
    is_leader = opts.rank == 0

    def _extract_features(stats_fn, rel_hi, max_items):
        # Shared wrapper around the two metric_utils feature-stat calls.
        stats = stats_fn(
            opts=opts,
            detector_url=detector_url,
            detector_kwargs=detector_kwargs,
            rel_lo=0,
            rel_hi=rel_hi,
            capture_all=True,
            max_items=max_items,
        )
        return stats.get_all_torch().to(torch.float16).to(opts.device)

    def _pairwise(rows, cols):
        # Distance matrix for this row chunk against all columns (rank 0 only).
        return compute_distances(
            row_features=rows,
            col_features=cols,
            num_gpus=opts.num_gpus,
            rank=opts.rank,
            col_batch_size=col_batch_size,
        )

    real_features = _extract_features(
        metric_utils.compute_feature_stats_for_dataset, rel_hi=0, max_items=max_real
    )
    gen_features = _extract_features(
        metric_utils.compute_feature_stats_for_generator, rel_hi=1, max_items=num_gen
    )

    roles = (
        ("precision", real_features, gen_features),
        ("recall", gen_features, real_features),
    )
    scores = {}
    for metric_name, manifold, probes in roles:
        # Manifold radii. Each manifold row is compared against the full
        # manifold, including itself, so the (nhood_size + 1)-th smallest
        # distance skips the zero self-distance. Non-leader ranks still call
        # _pairwise to take part in the collectives.
        radius_parts = []
        for manifold_chunk in manifold.split(row_batch_size):
            dist = _pairwise(manifold_chunk, manifold)
            if is_leader:
                kth_smallest = dist.to(torch.float32).kthvalue(nhood_size + 1).values
                radius_parts.append(kth_smallest.to(torch.float16))
        radii = torch.cat(radius_parts) if is_leader else None

        # Coverage test. A probe is inside the manifold if its distance to any
        # manifold sample is within that sample's radius.
        covered_parts = []
        for probe_chunk in probes.split(row_batch_size):
            dist = _pairwise(probe_chunk, manifold)
            if is_leader:
                covered_parts.append((dist <= radii).any(dim=1))

        if is_leader:
            scores[metric_name] = float(torch.cat(covered_parts).to(torch.float32).mean())
        else:
            scores[metric_name] = float("nan")

    return scores["precision"], scores["recall"]


# ----------------------------------------------------------------------------
